
Conversation

@mitruska (Contributor) commented Sep 23, 2025:

Details:

This transformation runs at compile time and is not enabled by default; it should be enabled explicitly in each plugin that has MOE support.
Example registration of the fusion transformation for the CPU plugin: 41145cf
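
For illustration, a minimal sketch of how a plugin could register the fusion in its transformation pipeline, assuming the usual ov::pass::Manager flow (the function name and include path below are hypothetical; the actual CPU plugin wiring is in the commit above):

#include "openvino/pass/manager.hpp"
// Hypothetical include for the fusion pass; the real header path may differ.
// #include "transformations/vectorized_experts_fusion.hpp"

void enable_moe_fusion(ov::pass::Manager& manager) {
    // Registers both the 2-GEMM and 3-GEMM matchers via the GraphRewrite wrapper shown below.
    manager.register_pass<ov::pass::VectorizedExpertsFusion>();
}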

  • Fuse vectorized MatMul experts into MOE for the 3-GEMM and 2-GEMM patterns:
class ov::pass::VectorizedExpertsFusion : public ov::pass::GraphRewrite {
public:
    OPENVINO_GRAPH_REWRITE_RTTI("VectorizedExpertsFusion");
    VectorizedExpertsFusion() {
        add_matcher<ov::pass::FuseVectorizedMOE2GEMM>();
        add_matcher<ov::pass::FuseVectorizedMOE3GEMM>();
    }
};
  • Add internal MOE op

MOE internal op spec PR:

Preliminary requirements (offline transformations):

Tickets:

  • transformation (and fusion details): 173663, op: 171913

@mitruska self-assigned this Sep 23, 2025
@github-actions bot added the "category: Core" (OpenVINO Core, aka ngraph), "category: transformations" (OpenVINO Runtime library - Transformations), and "category: CPP API" (OpenVINO CPP API bindings) labels Sep 23, 2025
@github-actions bot added the "category: CPU" (OpenVINO CPU plugin) label and removed the "category: CPP API" label Sep 24, 2025
@mitruska changed the title from "[Transformations][MOE] Fuse vectorized MatMul experts into MOE" to "[Transformations][MOE] Add MOE internal op and fuse vectorized MatMul experts into MOE" Sep 24, 2025
@mitruska marked this pull request as ready for review September 24, 2025 08:50
@mitruska requested review from a team as code owners September 24, 2025 08:50
@mitruska requested review from CuriousPanCake and removed request for a team September 24, 2025 08:50
@maxnick self-assigned this Sep 24, 2025
OV_OP_SCOPE(internal_MOE_validate_and_infer_types);
// TODO: Add inputs validation

set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
Collaborator commented:

We can also use the shape of the weights to deduce dimension sizes if some of them are unknown in the input hidden_state.
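
For example, a rough sketch of that idea (not the merged implementation; it assumes w0_weight is input 3 and that hidden_size is its last dimension, as in the doc comment quoted later in this thread):

auto output_shape = get_input_partial_shape(0);
const auto& w0_shape = get_input_partial_shape(3);
// If the hidden dimension of the input is dynamic, try to recover it from the
// static weight shape, assuming a [num_experts, ..., hidden_size] layout.
if (output_shape.rank().is_static() && w0_shape.rank().is_static()) {
    auto& hidden_dim = output_shape[output_shape.size() - 1];
    const auto& weight_hidden_dim = w0_shape[w0_shape.size() - 1];
    if (hidden_dim.is_dynamic() && weight_hidden_dim.is_static()) {
        hidden_dim = weight_hidden_dim;
    }
}
set_output_type(0, get_input_element_type(0), output_shape);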

@yeonbok (Contributor) commented Oct 8, 2025:

When you merge this transform, please disable it in the GPU plugin until we support the MoE subgraph, to prevent crashes with the default packages.
Once we support it, we'll turn it on.

/// (input to final multiplication)
/// 2: router_topk_output_indices - [..., topk] indices of selected top-k experts
/// 3: w0_weight - expert weights for first projection, shape [num_experts, inter_size, hidden_size] or
/// [num_experts, hidden_size, 2 * inter_size] if fused
Contributor commented:

I think [num_experts, hidden_size, 2 * inter_size] will be transposed to [num_experts, 2 * inter_size, hidden_size].
If there is a case where the weights are not transposed, we'll need a flag indicating whether the weight is transposed or not.
Or, if we can assume that the fused MoE always has the weights transposed, that would be best.

Contributor Author replied:

As we discussed, it will be adjusted to reflect MatMul(transpose_a=False, transpose_b=True);
related PR:
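
For illustration, a minimal sketch of that convention (the helper and variable names are hypothetical, not part of the final op spec):

#include <memory>
#include "openvino/op/matmul.hpp"

// Per-expert first projection under the agreed layout:
// activations: [num_tokens, hidden_size], w0_expert: [2 * inter_size, hidden_size]
std::shared_ptr<ov::Node> make_expert_projection(const ov::Output<ov::Node>& activations,
                                                 const ov::Output<ov::Node>& w0_expert) {
    // MatMul(transpose_a=false, transpose_b=true) -> [num_tokens, 2 * inter_size]
    return std::make_shared<ov::op::v0::MatMul>(activations, w0_expert,
                                                /*transpose_a=*/false,
                                                /*transpose_b=*/true);
}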

@mitruska (Contributor Author) replied:
> When you merge this transform, please disable it in the GPU plugin until we support the MoE subgraph, to prevent crashes with the default packages. Once we support it, we'll turn it on.

This transformation is not enabled by default; it should be enabled explicitly in each plugin that has MOE support.
Example registration of the fusion transformation for the CPU plugin: 41145cf

@mitruska added this to the 2025.4 milestone Oct 16, 2025
@mlukasze requested review from rkazants and yeonbok October 16, 2025 11:52
@mryzhov self-assigned this Oct 16, 2025
Merged via the queue into openvinotoolkit:master with commit a364bf5 Oct 16, 2025
207 checks passed
github-merge-queue bot pushed a commit that referenced this pull request Nov 5, 2025
### Details:
In this PR we introduce yet another operation, "GatherMatmul", which essentially
performs gemv operations over the current tokens and the active experts.
As the first step, we perform the gemv operation using dnnl::inner_product. This
solution is obviously suboptimal, as it doesn't give fine-grained control over
parallelization; in the case of many tokens being processed by a specific expert
(prefill), a gemm operation may be more optimal, since the tokens can be batched
and we can also do SIMD-level parallelization across tokens.
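
As a rough illustration of the gathered gemv idea, a simplified reference-style sketch (plain loops, one active expert per token; not the actual implementation, which dispatches to dnnl::inner_product):

#include <cstddef>
#include <cstdint>

// tokens:     [num_tokens, hidden_size]
// weights:    [num_experts, out_size, hidden_size]
// expert_ids: [num_tokens] index of the active expert for each token
// out:        [num_tokens, out_size]
void gather_matmul_ref(const float* tokens, const float* weights, const int32_t* expert_ids,
                       float* out, std::size_t num_tokens, std::size_t hidden_size, std::size_t out_size) {
    for (std::size_t t = 0; t < num_tokens; ++t) {
        const float* w = weights + static_cast<std::size_t>(expert_ids[t]) * out_size * hidden_size;
        const float* x = tokens + t * hidden_size;
        float* y = out + t * out_size;
        for (std::size_t o = 0; o < out_size; ++o) {  // gemv: y = W * x for the selected expert
            float acc = 0.0f;
            for (std::size_t h = 0; h < hidden_size; ++h) {
                acc += w[o * hidden_size + h] * x[h];
            }
            y[o] = acc;
        }
    }
}
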
This PR also contains all the essential transformations needed to enable a few
common MoE patterns.

The MoE pattern matcher is based on #32183

Related oneDNN fork PR:
openvinotoolkit/oneDNN#292

### Tickets:
 - CVS-171910

---------

Co-authored-by: Vladislav Golubev <[email protected]>